#include <ATen/ATen.h>
#include <torch/extension.h>
#include <torch/torch.h>

#include <iostream>
#include <map>
#include <vector>

#include "norm_sample.h"

// define this enum:
enum bn_type { BN_FWD, BN_BWD };

// this is a global variable
static std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;

at::Tensor gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias,
                       const at::Tensor& running_mean, const at::Tensor& running_var, const at::Tensor& minibatch_mean,
                       const at::Tensor& minibatch_inv_var, const float momentum, const float epsilon,
                       const int64_t bn_group, const int rank_id, const std::vector<int64_t>& peer_buffers) {
  int64_t N = x.size(0);
  int64_t C = x.size(1);
  int64_t H = x.size(2);
  int64_t W = x.size(3);

  int64_t tensorDims[] = {N, C, H, W};
  int64_t peerDims[] = {bn_group, 4 * C, 1, 1};
  int64_t perChannelDims[] = {1, C, 1, 1};
  int64_t epsilonDims[] = {1, 1, 1, 1};

  // Allocate output tensor
  at::Tensor y = at::empty_like(x);

  std::vector<void*> void_peer_buffers;
  for (int64_t addr : peer_buffers) {
    void_peer_buffers.push_back((void*)addr);
  }

  // we need the peer size for the buffer reset
  size_t peer_size = 1;
  for (size_t i = 0; i < 4; ++i) {
    peer_size *= peerDims[i];
  }

  // sanity check
  assert(bn_group == void_peer_buffers.size());

  // check if plan already exists
  std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
  if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) {
    auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
    gbn_plan_cache.emplace(fv, std::move(plan));
  }

  // get plan and handle
  auto plan = gbn_plan_cache.find(fv)->second;

  // execute
  execute_batch_norm_forward(plan, x.data_ptr(), y.data_ptr(), scale.data_ptr(), bias.data_ptr(),
                             running_mean.data_ptr(), running_var.data_ptr(), running_mean.data_ptr(),
                             running_var.data_ptr(), minibatch_mean.data_ptr(), minibatch_inv_var.data_ptr(),
                             void_peer_buffers, static_cast<double>(epsilon), static_cast<double>(momentum), peer_size,
                             rank_id);

  return y;
}

std::vector<at::Tensor> gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale,
                                     const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var,
                                     const float epsilon, const int64_t bn_group, const int rank_id,
                                     const std::vector<int64_t>& peer_buffers) {
  int64_t N = x.size(0);
  int64_t C = x.size(1);
  int64_t H = x.size(2);
  int64_t W = x.size(3);

  int64_t tensorDims[] = {N, C, H, W};
  int64_t peerDims[] = {bn_group, 4 * C, 1, 1};
  int64_t perChannelDims[] = {1, C, 1, 1};
  int64_t epsilonDims[] = {1, 1, 1, 1};

  // Allocate output tensor
  // outputs
  at::Tensor x_grad, scale_grad, bias_grad;

  // Allocate outputs
  x_grad = at::empty_like(x);
  scale_grad = at::empty_like(scale);
  bias_grad = at::empty_like(scale);

  std::vector<void*> void_peer_buffers;
  for (int64_t addr : peer_buffers) {
    void_peer_buffers.push_back((void*)addr);
  }

  // we need the peer size for the buffer reset
  size_t peer_size = 1;
  for (size_t i = 0; i < 4; ++i) {
    peer_size *= peerDims[i];
  }

  assert(bn_group == void_peer_buffers.size());

  std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
  if (gbn_plan_cache.find(fv) == gbn_plan_cache.end()) {
    auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
    gbn_plan_cache.emplace(fv, std::move(plan));
  }

  // get plan and handle
  auto plan = gbn_plan_cache.find(fv)->second;

  // execute
  execute_batch_norm_backward(plan, x.data_ptr(), dy.data_ptr(), scale.data_ptr(), minibatch_mean.data_ptr(),
                              minibatch_inv_var.data_ptr(), void_peer_buffers, x_grad.data_ptr(), scale_grad.data_ptr(),
                              bias_grad.data_ptr(), static_cast<double>(epsilon), peer_size, rank_id);

  return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &gbn_forward, "Group batch norm forward", py::call_guard<py::gil_scoped_release>());
  m.def("backward", &gbn_backward, "Group batch backward", py::call_guard<py::gil_scoped_release>());
}
